import torch
from einops import rearrange

from pdbench.third_parties.esm.utils import residue_constants
from pdbench.third_parties.esm.utils.misc import unbinpack
from pdbench.third_parties.esm.utils.structure.protein_structure import (
    compute_alignment_tensors,
    compute_gdt_ts_no_alignment,
)


def compute_lddt(
    all_atom_pred_pos: torch.Tensor,
    all_atom_positions: torch.Tensor,
    all_atom_mask: torch.Tensor,
    cutoff: float = 15.0,
    eps: float = 1e-10,
    per_residue: bool = True,
    sequence_id: torch.Tensor | None = None,
) -> torch.Tensor:
    """
    Computes LDDT for a protein. Tensor sizes below include some optional dimensions. Specifically:
        Nstates:
            all_atom_pred_pos can contain multiple states in the first dimension which corresponds to outputs from different layers of a model (e.g. each IPA block). The return size will be [Nstates x Batch size] if this is included.
        Natoms:
            LDDT can be computed for all atoms or some atoms. The second to last dimension should contain the *FLATTENED* representation of L x Natoms. If you want to calculate for atom37, e.g., this will be of size (L * 37). If you are only calculating CA LDDT, it will be of size L.

    Args:
        all_atom_pred_pos (Tensor[float], [(Nstates x) B x (L * Natoms x) 3]): Tensor of predicted positions
        all_atom_positions (Tensor[float], [B x (L * Natoms x) 3]): Tensor of true positions
        all_atom_mask (Tensor[float], [B x (L * Natoms)]): Tensor of masks, indicating whether an atom exists.
        cutoff (float): Max distance to score lddt over.
        per_residue (bool): Whether to return per-residue or full-protein lddt.
        sequence_id (Tensor, optional): Sequence id tensor for binpacking. NOTE: only supported for lddt_ca calculations, not when Natoms is passed!

    Returns:
        LDDT Tensor:
            if per_residue:
                Tensor[float], [(Nstates x) B x (L * Natoms)]
            else:
                Tensor[float], [(Nstates x) B]
    """
    n = all_atom_mask.shape[-2]
    dmat_true = torch.sqrt(
        eps
        + torch.sum(
            (all_atom_positions[..., None, :] - all_atom_positions[..., None, :, :])
            ** 2,
            dim=-1,
        )
    )

    dmat_pred = torch.sqrt(
        eps
        + torch.sum(
            (all_atom_pred_pos[..., None, :] - all_atom_pred_pos[..., None, :, :]) ** 2,
            dim=-1,
        )
    )
    dists_to_score = (
        (dmat_true < cutoff)
        * all_atom_mask
        * rearrange(all_atom_mask, "... a b -> ... b a")
        * (1.0 - torch.eye(n, device=all_atom_mask.device))
    )

    if sequence_id is not None:
        # TODO(roshan): This will work for lddt_ca, but not for regular lddt
        # Problem is that regular lddt has natoms * nres scores, so would need to repeat this mask by natoms
        # Leaving for now because it won't fail silently so should be ook.
        seqid_mask = sequence_id[..., None] == sequence_id[..., None, :]
        dists_to_score = dists_to_score * seqid_mask.type_as(dists_to_score)

    dist_l1 = torch.abs(dmat_true - dmat_pred)

    score = (
        (dist_l1 < 0.5).type(dist_l1.dtype)
        + (dist_l1 < 1.0).type(dist_l1.dtype)
        + (dist_l1 < 2.0).type(dist_l1.dtype)
        + (dist_l1 < 4.0).type(dist_l1.dtype)
    )
    score = score * 0.25

    dims = (-1,) if per_residue else (-2, -1)
    norm = 1.0 / (eps + torch.sum(dists_to_score, dim=dims))
    score = norm * (eps + torch.sum(dists_to_score * score, dim=dims))

    return score


def compute_lddt_ca(
    all_atom_pred_pos: torch.Tensor,
    all_atom_positions: torch.Tensor,
    all_atom_mask: torch.Tensor,
    cutoff: float = 15.0,
    eps: float = 1e-10,
    per_residue: bool = True,
    sequence_id: torch.Tensor | None = None,
) -> torch.Tensor:
    """
    Computes LDDT restricted to the CA atom by slicing the inputs and delegating to `compute_lddt`.

    The CA slot is looked up via `residue_constants.atom_order["CA"]`. It is selected along the
    second-to-last axis of the position tensors and along the last axis of the mask.

    Args:
        all_atom_pred_pos (Tensor): Predicted positions. A 3-dimensional tensor is used as-is
            (presumably already CA-only, [B x L x 3] -- confirm with callers); any other rank has
            its CA slot selected.
        all_atom_positions (Tensor): True positions with an atom axis second to last; the CA slot is always selected.
        all_atom_mask (Tensor): Atom-existence mask with the atom axis last; the CA column is kept
            as a trailing axis of size 1.
        cutoff (float): Forwarded to `compute_lddt`.
        eps (float): Forwarded to `compute_lddt`.
        per_residue (bool): Forwarded to `compute_lddt`.
        sequence_id (Tensor, optional): Forwarded to `compute_lddt`.

    Returns:
        Whatever `compute_lddt` returns for the CA-only inputs.
    """
    ca_index = residue_constants.atom_order["CA"]

    # A rank-3 prediction carries no atom axis to index into, so it is passed through untouched.
    pred_ca = (
        all_atom_pred_pos
        if all_atom_pred_pos.dim() == 3
        else all_atom_pred_pos[..., ca_index, :]
    )
    true_ca = all_atom_positions[..., ca_index, :]
    # Slice (rather than index) so the mask keeps a trailing singleton axis.
    ca_mask = all_atom_mask[..., ca_index : ca_index + 1]

    return compute_lddt(
        pred_ca,
        true_ca,
        ca_mask,
        cutoff=cutoff,
        eps=eps,
        per_residue=per_residue,
        sequence_id=sequence_id,
    )


def compute_gdt_ts(
    mobile: torch.Tensor,
    target: torch.Tensor,
    atom_exists_mask: torch.Tensor | None = None,
    sequence_id: torch.Tensor | None = None,
    reduction: str = "per_sample",
):
    """
    Superimpose `mobile` onto `target` and compute GDT_TS between the two batches of structures,
    with support for masking invalid atoms.

    Args:
    - mobile (torch.Tensor): Batch of coordinates of the structure to be superimposed, shape (B, N, 3)
    - target (torch.Tensor): Batch of coordinates of the fixed structure, shape (B, N, 3)
    - atom_exists_mask (torch.Tensor, optional): Mask of shape (B, N) marking which atoms exist.
        When omitted, an atom counts as existing if all of its `target` coordinates are finite.
    - sequence_id (torch.Tensor, optional): Sequence id tensor for binpacking.
    - reduction (str): One of "batch", "per_sample", "per_residue". Passed through unchanged to
        `compute_gdt_ts_no_alignment`.

    Returns:
    If reduction == "batch":
        (torch.Tensor): 0-dim, GDT_TS reduced over the whole batch
    If reduction == "per_sample":
        (torch.Tensor): (B,)-dim, GDT_TS for each sample in the batch
    (The result for "per_residue" is determined by `compute_gdt_ts_no_alignment`.)
    """
    exists_mask = (
        torch.isfinite(target).all(dim=-1)
        if atom_exists_mask is None
        else atom_exists_mask
    )

    # Only the centered coordinates and the rotation are needed from the alignment outputs.
    mobile_centered, _, target_centered, _, rotation, _ = compute_alignment_tensors(
        mobile=mobile,
        target=target,
        atom_exists_mask=exists_mask,
        sequence_id=sequence_id,
    )

    # Rotate the centered mobile structure onto the centered target.
    mobile_aligned = torch.matmul(mobile_centered, rotation)

    # The coordinate tensors returned by `compute_alignment_tensors` are unbinpacked and contain
    # zeros for invalid positions, so the mask handed to `compute_gdt_ts_no_alignment` must be
    # unbinpacked the same way.
    if sequence_id is not None:
        exists_mask = unbinpack(exists_mask, sequence_id, pad_value=False)

    return compute_gdt_ts_no_alignment(
        mobile_aligned, target_centered, exists_mask, reduction
    )
